# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import collections
from typing import Union

import numpy as np

# normal dtype related
from .._imperative_rt import bfloat16, intb1, intb2, intb4


def is_lowbit(dtype):
    """
    Tell whether ``dtype`` is one of the low-bit integer types.

    The check is by identity against the ``intb1``, ``intb2`` and ``intb4``
    objects imported from ``_imperative_rt``.

    :param dtype: any object.
    :return: ``True`` if ``dtype`` is exactly one of those three objects.
    """
    return any(dtype is lowbit_type for lowbit_type in (intb1, intb2, intb4))


def is_bfloat16(dtype):
    """
    Tell whether ``dtype`` is the ``bfloat16`` type imported from ``_imperative_rt``.

    :param dtype: any object.
    :return: ``True`` only when ``dtype`` is that exact object (identity check).
    """
    matches_bf16 = dtype is bfloat16
    return matches_bf16


# quantization dtype related
_QuantDtypeMetadata = collections.namedtuple(
    "QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
)

_metadata_dict = {
    "quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255),
    "qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127),
    "quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15),
    "qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7),
    "qint32": _QuantDtypeMetadata(
        "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
    ),
    # NOTE: int2 is not supported for model dump yet
    "quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3),
    "qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1),
}


def is_quantize(dtype):
    """
    Tell whether ``dtype`` is a quantized dtype.

    A dtype counts as quantized when it has a non-``None`` ``metadata``
    mapping containing the ``"mgb_dtype"`` key, as built by
    ``get_quantized_dtype``.

    :param dtype: any object; objects without ``metadata`` are simply not quantized.
    :return: ``True`` or ``False``.
    """
    meta = getattr(dtype, "metadata", None)
    if meta is None:
        return False
    return "mgb_dtype" in meta


def get_scale(dtype):
    """
    Return the quantization ``scale`` stored in ``dtype``'s metadata.

    :param dtype: a quantized dtype, i.e. one for which ``is_quantize`` is true
        (checked with ``assert``).
    :return: the stored scale value.
    """
    assert is_quantize(dtype)
    mgb_meta = dtype.metadata["mgb_dtype"]
    return mgb_meta["scale"]


def get_zero_point(dtype):
    """
    Return the ``zero_point`` stored in a quantized dtype's metadata.

    ``get_quantized_dtype`` stores a zero_point only for unsigned (asymmetric)
    types such as quint8, quint4 and quint2; signed types carry just a scale.

    :param dtype: a quantized dtype (see ``is_quantize``) that has a zero_point.
    :return: the stored zero_point.
    """
    assert is_quantize(dtype)
    metadata = dtype.metadata["mgb_dtype"]
    # Key off the presence of zero_point instead of whitelisting dtype names:
    # quint2 dtypes carry a zero_point but have name None, so a name check
    # wrongly rejects them.
    assert "zero_point" in metadata
    return metadata["zero_point"]


def is_equal(dt0, dt1):
    """
    Check whether two dtypes are equivalent, taking quantization into account.

    - Two quantized dtypes are equal when they have the same mgb dtype name,
      the same numpy storage dtype, the same scale and the same zero_point
      (``None`` for signed types).
    - Two non-quantized dtypes are compared with ``==``.
    - A quantized dtype never equals a non-quantized one.

    :return: ``True`` if the dtypes are equivalent, else ``False``.
    """

    def _get_zero_point(dtype):
        assert is_quantize(dtype)
        metadata = dtype.metadata["mgb_dtype"]
        # signed quantized dtypes store no zero_point; treat it as None
        return metadata.get("zero_point")

    def _get_name(dtype):
        return dtype.metadata["mgb_dtype"].get("name")

    quant0, quant1 = is_quantize(dt0), is_quantize(dt1)
    if quant0 and quant1:
        return (
            # the name distinguishes e.g. qint4 from qint8 (both stored as int8)
            _get_name(dt0) == _get_name(dt1)
            # ``.str`` is the bare storage type string (no metadata); it
            # distinguishes quint2 (uint8) from qint2 (int8), which share name None
            and dt0.str == dt1.str
            and get_scale(dt0) == get_scale(dt1)
            and _get_zero_point(dt0) == _get_zero_point(dt1)
        )
    if not (quant0 or quant1):
        return dt0 == dt1
    return False


def _check_zero_point(zp: int, dtype_str: str):
    """
    Validate that ``zp`` lies in the representable range of ``dtype_str``.

    :param zp: candidate zero_point.
    :param dtype_str: key into ``_metadata_dict`` (e.g. ``"quint8"``).
    :raises ValueError: if ``zp`` is outside ``[qmin, qmax]`` for ``dtype_str``.
    """
    spec = _metadata_dict[dtype_str]
    lower, upper = spec.qmin, spec.qmax
    if zp < lower or zp > upper:
        raise ValueError(
            "zero_point should be within [{}, {}] for {}".format(lower, upper, dtype_str)
        )


def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
    r"""
    Get quantized dtype with metadata attribute according to _metadata_dict.

    Note that unsigned dtype must have ``zero_point`` and signed dtype must
    not have ``zero_point``, to be consistent with tensor generated by calling
    compiled function from `CompGraph.compile(inputs, outspec)`.

    The returned numpy dtype uses the storage type from ``_metadata_dict`` and
    carries ``metadata={"mgb_dtype": {"name", "scale"[, "zero_point"]}}``.

    :param dtype_str: a string indicating which dtype to return, i.e. a key of
        ``_metadata_dict`` (``KeyError`` otherwise)
    :param scale: a number for scale to store in dtype's metadata (stored as float)
    :param zp: a number for zero_point to store in dtype's metadata; required and
        integral for unsigned dtypes, ignored for signed ones
    :raises ValueError: for unsigned dtypes, if ``zp`` is missing, non-integral,
        or out of the dtype's range
    """
    spec = _metadata_dict[dtype_str]
    extra_fields = {}
    if spec.is_unsigned:
        if zp is None or int(zp) != zp:
            raise ValueError("zero_point should be an integer")
        zp = int(zp)
        _check_zero_point(zp, dtype_str)
        extra_fields["zero_point"] = zp
    # key order stays name, scale, zero_point
    mgb_dtype = {"name": spec.name, "scale": float(scale), **extra_fields}
    return np.dtype(spec.np_dtype_str, metadata={"mgb_dtype": mgb_dtype})


def quint8(scale, zero_point):
    """
    Construct a quantized unsigned int8 data type with ``scale`` (float) and
    ``zero_point`` (integer in [0, 255]). The real value represented by a quint8
    data type is float_val = scale * (uint8_val - zero_point)
    """
    return get_quantized_dtype(dtype_str="quint8", scale=scale, zp=zero_point)


def qint8(scale):
    """
    Construct a quantized signed int8 data type with ``scale`` (float) and no
    zero_point. The real value represented by a qint8 data type is
    float_val = scale * int8_val
    """
    return get_quantized_dtype(dtype_str="qint8", scale=scale, zp=None)


def qint32(scale):
    """
    Construct a quantized signed int32 data type with ``scale`` (float) and no
    zero_point. The real value represented by a qint32 data type is
    float_val = scale * int32_val
    """
    return get_quantized_dtype(dtype_str="qint32", scale=scale, zp=None)


def quint4(scale, zero_point):
    """
    Construct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (integer in [0, 15]); values are stored in uint8. The real
    value represented by a quint4 data type is
    float_val = scale * (uint4_val - zero_point)
    """
    return get_quantized_dtype(dtype_str="quint4", scale=scale, zp=zero_point)


def qint4(scale):
    """
    Construct a quantized signed int4 data type with ``scale`` (float) and no
    zero_point; values are stored in int8. The real value represented by a
    qint4 data type is float_val = scale * int4_val
    """
    return get_quantized_dtype(dtype_str="qint4", scale=scale, zp=None)


def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
    """
    Quantize the real values in ``arr`` into the quantized dtype ``dtype``.

    Each value is divided by the dtype's scale and rounded with ``np.round``.
    For unsigned types it is then shifted by the zero_point. Finally it is
    saturated to ``[qmin, qmax]`` of ``dtype_str``.

    :param arr: ndarray of real values.
    :param dtype: target quantized dtype; must be a ``dtype_str`` dtype.
    :param dtype_str: key into ``_metadata_dict`` (e.g. ``"qint8"``).
    :return: ndarray cast to ``dtype``.
    :raises ValueError: if ``arr`` is not an ndarray or ``dtype`` is not a
        ``dtype_str`` quantized dtype.
    """
    metadata = _metadata_dict[dtype_str]
    if not isinstance(arr, np.ndarray):
        raise ValueError("arr parameter should be instance of np.ndarray")
    # Run is_quantize before indexing ``dtype.metadata``: for a non-quantized
    # dtype the metadata is None (or missing). Indexing it first would raise
    # TypeError/AttributeError instead of the intended ValueError.
    if (
        not is_quantize(dtype)
        or dtype.metadata["mgb_dtype"]["name"] != metadata.name
    ):
        raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
    arr_metadata = dtype.metadata["mgb_dtype"]
    # Round in the input's own precision, then widen to float64 so the
    # saturation bounds are exact. 2**31 - 1 is not representable in float32
    # and rounds up to 2**31, so clipping a float32 array for qint32 and then
    # casting would overflow int32.
    quantized = np.round(arr / arr_metadata["scale"]).astype(np.float64)
    if metadata.is_unsigned:
        # only unsigned dtypes built by get_quantized_dtype store a zero_point
        quantized = quantized + arr_metadata["zero_point"]
    return quantized.clip(metadata.qmin, metadata.qmax).astype(dtype)


def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
    """
    Dequantize ``arr``, whose dtype must be the quantized type ``dtype_str``.

    :param arr: ndarray with a ``dtype_str`` quantized dtype.
    :param dtype_str: key into ``_metadata_dict`` (e.g. ``"quint8"``).
    :return: ndarray of real values computed from ``arr.astype(np.float32)``:
        ``(q - zero_point) * scale`` for unsigned types, ``q * scale`` for signed.
    :raises ValueError: if ``arr`` is not an ndarray or its dtype is not a
        ``dtype_str`` quantized dtype.
    """
    metadata = _metadata_dict[dtype_str]
    # Validate before reading ``arr.dtype.metadata``. A non-ndarray may have no
    # ``dtype``, and a plain dtype has ``metadata is None``; reading first
    # would raise AttributeError/TypeError instead of the intended ValueError.
    if not isinstance(arr, np.ndarray):
        raise ValueError("arr parameter should be instance of np.ndarray")
    if (
        not is_quantize(arr.dtype)
        or arr.dtype.metadata["mgb_dtype"]["name"] != metadata.name
    ):
        raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
    arr_metadata = arr.dtype.metadata["mgb_dtype"]
    dequantized = arr.astype(np.float32)
    if metadata.is_unsigned:
        # only unsigned dtypes built by get_quantized_dtype store a zero_point
        dequantized = dequantized - arr_metadata["zero_point"]
    return dequantized * arr_metadata["scale"]


def convert_to_quint8(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a quint8 one with specified params.

    :param arr: Input ndarray of real values.
    :param q: Target data type, should be a quint8 (e.g. from ``quint8(...)``);
        its scale and zero_point come from its metadata.
    :return: ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, dtype=q, dtype_str="quint8")


def convert_from_quint8(arr: np.ndarray):
    """
    Dequantize a quint8 NumPy ndarray into a float one.

    :param arr: Input ndarray whose dtype is a quint8.
    :return: ndarray of real values, computed in float32.
    """
    return _convert_from_quantized_dtype(arr, dtype_str="quint8")


def convert_to_qint8(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint8 one with specified params.

    :param arr: Input ndarray of real values.
    :param q: Target data type, should be a qint8 (e.g. from ``qint8(...)``);
        its scale comes from its metadata.
    :return: ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, dtype=q, dtype_str="qint8")


def convert_from_qint8(arr: np.ndarray):
    """
    Dequantize a qint8 NumPy ndarray into a float one.

    :param arr: Input ndarray whose dtype is a qint8.
    :return: ndarray of real values, computed in float32.
    """
    return _convert_from_quantized_dtype(arr, dtype_str="qint8")


def convert_to_qint32(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint32 one with specified params.

    :param arr: Input ndarray of real values.
    :param q: Target data type, should be a qint32 (e.g. from ``qint32(...)``);
        its scale comes from its metadata.
    :return: ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, q, "qint32")


def convert_from_qint32(arr: np.ndarray):
    """
    Dequantize a qint32 NumPy ndarray into a float one.

    :param arr: Input ndarray whose dtype is a qint32.
    :return: ndarray of real values, computed in float32.
    """
    return _convert_from_quantized_dtype(arr, "qint32")


def convert_to_quint4(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a quint4 one with specified params.

    :param arr: Input ndarray of real values.
    :param q: Target data type, should be a quint4 (e.g. from ``quint4(...)``);
        its scale and zero_point come from its metadata.
    :return: ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, dtype=q, dtype_str="quint4")


def convert_from_quint4(arr: np.ndarray):
    """
    Dequantize a quint4 NumPy ndarray into a float one.

    :param arr: Input ndarray whose dtype is a quint4.
    :return: ndarray of real values, computed in float32.
    """
    return _convert_from_quantized_dtype(arr, dtype_str="quint4")


def convert_to_qint4(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint4 one with specified params.

    :param arr: Input ndarray of real values.
    :param q: Target data type, should be a qint4 (e.g. from ``qint4(...)``);
        its scale comes from its metadata.
    :return: ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, dtype=q, dtype_str="qint4")


def convert_from_qint4(arr: np.ndarray):
    """
    Dequantize a qint4 NumPy ndarray into a float one.

    :param arr: Input ndarray whose dtype is a qint4.
    :return: ndarray of real values, computed in float32.
    """
    return _convert_from_quantized_dtype(arr, dtype_str="qint4")
